# Distributed TensorFlow training on GCP with 2 V100 GPUs
#
# Uses MultiWorkerMirroredStrategy to distribute training across multiple nodes with SkyPilot.
#
# Usage:
#   sky launch -c myclus tf_distributed.yaml
#   sky down myclus

# Per-node hardware request: each node is a GCP VM with a single V100 GPU.
resources:
  cloud: gcp
  # Provision 1 V100 GPU per node
  accelerators: 'V100:1'

# Provision 2 nodes. With 1 GPU requested per node, the cluster has 2 GPUs in
# total.
num_nodes: 2

# Copy the current working directory to all nodes on the cluster. It must
# contain train.py, which the run section executes.
workdir: .

# Setup phase: runs on every node to install dependencies before the run
# commands. TensorFlow is pinned to an exact version.
setup: |-
  pip install tensorflow==2.11.0

# The following shell commands are run on all nodes at execution
run: |
  # Port the TensorFlow workers listen on. Can be any unused port.
  # Exported so the Python snippet below reads this single definition instead
  # of carrying its own hard-coded copy.
  export PORT=2222

  # ======== Construct TF_CONFIG ========
  # Since we are using MultiWorkerMirroredStrategy, we need to construct a
  # TF_CONFIG environment variable that contains the list of all worker
  # addresses and the current node's rank. SkyPilot provides the
  # SKYPILOT_NODE_IPS and SKYPILOT_NODE_RANK environment variables to get this
  # information.
  #
  # Examples of envvars (SKYPILOT_NODE_IPS holds one IP address per line):
  #   SKYPILOT_NODE_IPS="192.168.0.1
  #   192.168.0.2"
  #   SKYPILOT_NODE_RANK="1"
  #
  # The JSON is printed to stdout and captured directly, so no temporary file
  # is needed. The heredoc delimiter is quoted so the shell passes the Python
  # source through verbatim.
  TF_CONFIG=$(python -u - <<'EOF'
  import json
  import os

  # Indexing os.environ (rather than .get) makes a missing variable fail
  # loudly here instead of producing a malformed TF_CONFIG.
  port = int(os.environ['PORT'])

  # split() with no argument accepts any whitespace between addresses and
  # ignores leading/trailing whitespace.
  node_ips = os.environ['SKYPILOT_NODE_IPS'].split()

  # The task index is written as a JSON integer, matching the documented
  # TF_CONFIG format.
  node_rank = int(os.environ['SKYPILOT_NODE_RANK'])

  tf_config = {
      'cluster': {
          'worker': [f'{ip}:{port}' for ip in node_ips],
      },
      'task': {
          'type': 'worker',
          'index': node_rank,
      },
  }
  print(json.dumps(tf_config))
  EOF
  ) || { echo "Failed to construct TF_CONFIG" >&2; exit 1; }

  # Assign first and export separately: `export VAR=$(...)` would hide a
  # failing exit status from the command substitution.
  export TF_CONFIG
  echo "$TF_CONFIG"

  # ======== Run the training script ========
  python train.py
